package com.whoyx.jiebing.utils.nlp.sentence_encoder;

import ai.djl.Device;
import ai.djl.repository.zoo.Criteria;
import ai.djl.training.util.ProgressBar;
import com.whoyx.jiebing.config.NlpFilePathConfig;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import java.nio.file.Path;
import java.nio.file.Paths;

/**
 * Builds the DJL {@link Criteria} used to load the sentence-encoder model.
 *
 * <p>The model location is read from {@link NlpFilePathConfig#getSentenceEncoderPath()}. The
 * criteria pin the PyTorch engine on the CPU and use {@link SentenceTransTranslator} to map a
 * {@code String} input to a {@code float[]} output.
 *
 * @author stone
 */
@Slf4j
@Component
@RequiredArgsConstructor
public final class SentenceEncoder {

    /** Engine name handed to DJL; the model at the configured path must be loadable by it. */
    private static final String ENGINE = "PyTorch";

    private final NlpFilePathConfig pathConfig;

    /**
     * Creates a new criteria describing how to load the sentence-encoder model.
     *
     * <p>Only the description is built here. Loading is left to the caller (for example through
     * {@code Criteria#loadModel()}), and the resulting model should be closed when no longer
     * needed.
     *
     * @return criteria for a {@code String -> float[]} model using the PyTorch engine on the CPU
     * @throws IllegalStateException if the configured model path is {@code null} or blank
     * @throws java.nio.file.InvalidPathException if the configured path cannot be parsed
     */
    public Criteria<String, float[]> criteria() {
        Path modelPath = resolveModelPath();
        return Criteria.builder()
                .setTypes(String.class, float[].class)
                .optModelPath(modelPath)
                .optTranslator(new SentenceTransTranslator())
                .optEngine(ENGINE)
                .optDevice(Device.cpu())
                .optProgress(new ProgressBar())
                .build();
    }

    /**
     * Reads the configured model path and rejects a missing or blank value.
     *
     * <p>Fails fast on purpose: a {@code null} path would otherwise surface as a bare NPE from
     * {@link Paths#get}, and a blank one silently resolves to the working directory and only
     * fails later, when the model is actually loaded.
     */
    private Path resolveModelPath() {
        String configured = pathConfig.getSentenceEncoderPath();
        if (configured == null || configured.trim().isEmpty()) {
            throw new IllegalStateException("Sentence encoder model path is not configured");
        }
        return Paths.get(configured);
    }

}
